use core::alloc::Layout;
use core::mem::{self, MaybeUninit};
use core::ops::{Deref, DerefMut};
use core::ptr::{self, NonNull};
use core::slice;
use hipool::{Allocator, Error, PoolAlloc};

/// A LIFO stack over a caller-provided, fixed-capacity buffer.
///
/// Slots `data[..len]` hold initialized values and slots `data[len..]` are
/// uninitialized; the capacity is `data.len()`.
pub struct Stack<'a, T> {
    /// Backing storage; only the first `len` slots are initialized.
    data: &'a mut [MaybeUninit<T>],
    /// Number of initialized elements at the front of `data`.
    len: usize,
}

// SAFETY: a `Stack` owns the `T` values stored in its buffer, so moving it to
// another thread moves those values; that is sound exactly when `T: Send`.
unsafe impl<'a, T> Send for Stack<'a, T> where T: Send {}
// SAFETY: a shared `&Stack` only hands out `&T` (e.g. via `peek`), so sharing it
// across threads is sound exactly when `T: Sync`.
unsafe impl<'a, T> Sync for Stack<'a, T> where T: Sync {}

impl<T> Stack<'_, T> {
    /// Pushes `val` onto the top of the stack.
    ///
    /// Returns `true` on success, or `false` when the buffer is already full
    /// (in which case `val` is dropped).
    pub fn push(&mut self, val: T) -> bool {
        match self.data.get_mut(self.len) {
            Some(slot) => {
                slot.write(val);
                self.len += 1;
                true
            }
            None => false,
        }
    }

    /// Returns a reference to the top element, or `None` if the stack is empty.
    pub fn peek(&self) -> Option<&T> {
        let top = self.len.checked_sub(1)?;
        // SAFETY: every slot below `len` has been written by `push`.
        Some(unsafe { self.data[top].assume_init_ref() })
    }

    /// Removes and returns the top element, or `None` if the stack is empty.
    pub fn pop(&mut self) -> Option<T> {
        let top = self.len.checked_sub(1)?;
        // Shrink first so the slot is no longer considered initialized once its
        // value has been moved out.
        self.len = top;
        // SAFETY: slot `top` was initialized by `push` and is read exactly once.
        Some(unsafe { self.data[top].assume_init_read() })
    }

    /// Number of elements currently on the stack.
    pub fn len(&self) -> usize {
        self.len
    }
}

impl<T> Drop for Stack<'_, T> {
    /// Drops every element still on the stack, starting from the top.
    fn drop(&mut self) {
        // Element types without drop glue need no cleanup; skip the walk.
        if !mem::needs_drop::<T>() {
            return;
        }
        while let Some(top) = self.pop() {
            drop(top);
        }
    }
}

#[cfg(test)]
impl<T> Stack<'_, T> {
    /// Returns `true` when the stack holds no elements.
    pub fn is_empty(&self) -> bool {
        matches!(self.len, 0)
    }

    /// Maximum number of elements the backing buffer can hold.
    pub fn capacity(&self) -> usize {
        let buffer = &*self.data;
        buffer.len()
    }
}

#[cfg(test)]
impl<'a, T> Stack<'a, T> {
    /// Wraps `data` as an empty stack whose capacity is `data.len()`.
    pub fn new(data: &'a mut [MaybeUninit<T>]) -> Self {
        let len = 0;
        Stack { len, data }
    }
}

#[cfg(test)]
impl<'a, T> DynStack<'a, T> {
    /// Test helper: builds an empty stack with room for `len` elements, backed by
    /// the default [`PoolAlloc`] allocator.
    pub fn new(len: usize) -> Result<DynStack<'a, T, PoolAlloc>, Error> {
        // `&PoolAlloc` is a promoted constant, so it can be borrowed for any `'a`.
        let pool: &'a PoolAlloc = &PoolAlloc;
        Self::new_in(pool, len)
    }
}

impl<'a, T, A: Allocator> DynStack<'a, T, A> {
    pub fn new_in(alloc: &'a A, len: usize) -> Result<Self, Error> {
        let layout = Layout::array::<T>(len).unwrap();
        let ptr = unsafe { alloc.allocate(layout)?.cast::<MaybeUninit<T>>().as_ptr() };
        let data = unsafe { slice::from_raw_parts_mut(ptr, len) };
        Ok(Self {
            stack: Stack { data, len: 0 },
            alloc,
        })
    }
}

/// A [`Stack`] whose backing buffer is allocated from, and later returned to,
/// an [`Allocator`] (by default [`PoolAlloc`]).
///
/// Dereferences to the inner [`Stack`], so `push`/`pop`/`peek`/`len` are
/// available directly.
pub struct DynStack<'a, T, A: Allocator = PoolAlloc> {
    /// Stack over the allocated buffer; the buffer's length is the capacity.
    stack: Stack<'a, T>,
    /// Allocator that owns the buffer's memory; used to free it on drop.
    alloc: &'a A,
}

// SAFETY: sending a `DynStack` moves the `T` values it owns to another thread,
// so `T` must be `Send`; the `&'a A` it carries is `Send` only when `A: Sync`.
unsafe impl<T: Send, A: Sync + Allocator> Send for DynStack<'_, T, A> {}
// SAFETY: through `&DynStack`, other threads can only reach `&T` (via `Deref`
// to the inner stack); in this file the allocator is used only from `Drop`,
// which requires exclusive access.
unsafe impl<T: Sync, A: Allocator> Sync for DynStack<'_, T, A> {}

impl<'a, T, A: Allocator> Deref for DynStack<'a, T, A> {
    type Target = Stack<'a, T>;
    fn deref(&self) -> &Self::Target {
        &self.stack
    }
}

impl<T, A: Allocator> DerefMut for DynStack<'_, T, A> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.stack
    }
}

impl<T, A: Allocator> Drop for DynStack<'_, T, A> {
    fn drop(&mut self) {
        let layout = Layout::array::<T>(self.len()).unwrap();
        let ptr = self.stack.data.as_ptr() as *const u8;
        unsafe {
            let slice = slice::from_raw_parts(ptr, layout.size());
            self.alloc.deallocate(slice.into(), layout);
        }
    }
}

#[cfg(test)]
mod test {
    use super::{DynStack, Stack};
    use core::cell::Cell;
    use core::mem::MaybeUninit;

    /// Test element that bumps a caller-owned counter when dropped.
    ///
    /// Each test owns its own counter, so tests stay independent when the
    /// harness runs them in parallel, and no `static mut` is needed (reading one
    /// by reference trips `static_mut_refs`, deny-by-default in edition 2024).
    struct Foo<'c> {
        _val: usize,
        drops: &'c Cell<usize>,
    }

    impl Drop for Foo<'_> {
        fn drop(&mut self) {
            self.drops.set(self.drops.get() + 1);
        }
    }

    /// Dropping a full `Stack` drops every element exactly once.
    #[test]
    fn test_stack_drop() {
        let drops = Cell::new(0);
        let mut foos: [MaybeUninit<Foo<'_>>; 10] = [(); 10].map(|_| MaybeUninit::uninit());
        {
            let len = foos.len();
            let mut stack = Stack::new(&mut foos);
            for val in 0..len {
                assert!(stack.push(Foo { _val: val, drops: &drops }));
            }
        }
        assert_eq!(drops.get(), foos.len());
    }

    /// Dropping a partially filled `DynStack` drops exactly the live elements
    /// (regression test: elements must be dropped before the buffer is freed).
    #[test]
    fn test_dyn_stack_drop() {
        let drops = Cell::new(0);
        {
            let mut stack = DynStack::new(10).unwrap();
            for val in 0..3 {
                assert!(stack.push(Foo { _val: val, drops: &drops }));
            }
            assert_eq!(stack.len(), 3);
        }
        assert_eq!(drops.get(), 3);
    }

    /// Fill to capacity, reject overflow, then pop everything in LIFO order.
    #[test]
    fn test_stack_pop() {
        let mut stack = DynStack::new(10).unwrap();
        for i in 0..stack.capacity() {
            assert!(stack.push(i));
        }
        assert!(!stack.push(0));
        assert_eq!(stack.len(), stack.capacity());
        for i in (0..stack.capacity()).rev() {
            assert_eq!(stack.pop(), Some(i));
        }
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }
}
